# 1_run_embedding_optimized.py

import json
import time
import pickle
import re  # Regular expressions
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# Credentials and endpoint passed to the OpenAI client in get_embeddings().
# The values were blanked out before sharing; they must be filled in to run.
# NOTE(review): consider reading API_KEY from an environment variable rather
# than hard-coding a secret in source.
API_KEY = ""  # Removed for privacy
BASE_URL = ""  # Removed for privacy
# Model name sent with every client.embeddings.create() request.
EMBEDDING_MODEL = "text-embedding-3-small"

# --- File Paths ---
# JSON Lines input; the string under DATA_KEY in each record is the text to embed.
INPUT_FILE_PATH = ""  # Removed for privacy
DATA_KEY = "min_word_prompt"
# Pickle output written by save_data(): {"texts": [...], "vectors": [...]}.
OUTPUT_FILE_PATH = ""  # Removed for privacy


def load_texts_from_jsonl(file_path: str, key: str) -> list[str]:
    """
    Load text data from a JSON Lines (.jsonl) file.

    Every non-blank line is parsed as a JSON object, and the string stored
    under ``key`` is collected. Records whose value is the sentinel
    ``"NoRefuse"`` are counted and skipped. Records whose value is missing or
    empty are ignored. Unparsable lines, lines that are not JSON objects, and
    records whose value is not a string are skipped with a warning.

    Args:
        file_path: Path to the .jsonl file (read as UTF-8).
        key: Field name holding the text in each record.

    Returns:
        The collected texts in file order, or an empty list if the file does
        not exist.
    """
    print(f"Step 1/4: Loading data from file {file_path}...")
    texts: list[str] = []
    skipped_count = 0
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                # Blank lines (e.g. a trailing newline at EOF) carry no record.
                if not stripped:
                    continue
                try:
                    data = json.loads(stripped)
                except json.JSONDecodeError:
                    print(f"Warning: Skipping an unparsable JSON line: {stripped}")
                    continue
                # Valid JSON need not be an object (e.g. a bare list or string);
                # calling .get() on it would abort the whole load.
                if not isinstance(data, dict):
                    print(f"Warning: Skipping a JSON line that is not an object: {stripped}")
                    continue
                text = data.get(key)
                if text == "NoRefuse":
                    skipped_count += 1
                elif isinstance(text, str) and text:
                    texts.append(text)
                elif text:
                    # Truthy non-string values (numbers, lists, ...) cannot be
                    # cleaned or embedded downstream.
                    print(f"Warning: Skipping a record whose '{key}' value is not a string: {stripped}")
    except FileNotFoundError:
        print(f"Error: File {file_path} not found")
        return []
    print(f"Data loaded! Found {len(texts)} valid texts, skipped {skipped_count} 'NoRefuse' records.")
    return texts


def clean_texts_for_embedding(texts: list[str]) -> list[str]:
    """
    Normalize texts before they are sent for embedding.

    Each character outside the CJK range U+4E00-U+9FA5, ASCII letters, and
    ASCII digits becomes a space. Whitespace runs are then collapsed to a
    single space, and the result is stripped at both ends. The output list
    has one entry per input text, in the same order.
    """
    print("Step 2/4: Cleaning text data, converting punctuation to spaces...")
    # Both patterns are compiled once per call and reused for every text.
    non_word = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9]')
    spaces = re.compile(r'\s+')

    def _normalize(raw: str) -> str:
        spaced_out = non_word.sub(' ', raw)
        return spaces.sub(' ', spaced_out).strip()

    result = [_normalize(raw) for raw in texts]
    print("Text cleaning complete.")
    return result


def get_embeddings(
    texts: list[str],
    batch_size: int = 100,
    max_retries: int = 2,
) -> list[list[float]]:
    """
    Batch get embeddings for a list of texts.

    Texts are sent to the embeddings endpoint ``batch_size`` at a time.
    Transient failures are retried by the OpenAI client itself, up to
    ``max_retries`` times. If a batch still fails, processing stops.
    The returned vectors then form an aligned prefix of ``texts``: vector i
    always belongs to text i. A caller can detect the shortfall by comparing
    lengths. Continuing after a failure would shift every later vector
    against its text.

    Args:
        texts: Texts to embed.
        batch_size: Number of texts per API request; must be >= 1.
        max_retries: Retry budget forwarded to the OpenAI client (the SDK's
            own default is 2).

    Returns:
        One embedding vector per successfully processed text, in input order.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if not texts:
        return []
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    print(f"Step 3/4: Generating embeddings using model '{EMBEDDING_MODEL}'...")
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL, max_retries=max_retries)
    all_embeddings: list[list[float]] = []
    batch_starts = range(0, len(texts), batch_size)
    for batch_number, start in enumerate(tqdm(batch_starts, desc="Getting Embeddings"), start=1):
        batch = texts[start:start + batch_size]
        try:
            response = client.embeddings.create(input=batch, model=EMBEDDING_MODEL)
        except Exception as e:
            print(f"API error occurred while processing batch {batch_number}: {e}")
            print("Stopping early so that returned embeddings stay aligned with their texts.")
            break
        # Order by the per-item index the API reports so each vector lines up
        # with its input; if the field is absent, the stable sort keeps
        # response order.
        items = sorted(response.data, key=lambda item: getattr(item, "index", 0))
        if len(items) != len(batch):
            print(f"API returned {len(items)} embeddings for batch {batch_number} "
                  f"of {len(batch)} texts; stopping early.")
            break
        all_embeddings.extend(item.embedding for item in items)
        # Simple client-side pacing between requests; no need after the last one.
        if start + batch_size < len(texts):
            time.sleep(0.5)
    print(f"Embedding complete! Successfully obtained {len(all_embeddings)}/{len(texts)} embeddings.")
    return all_embeddings


def save_data(file_path: str, texts: list[str], vectors: list[list[float]]):
    """
    Pickle the raw texts together with their embedding vectors.

    On success, ``file_path`` receives the dict
    ``{"texts": texts, "vectors": vectors}``. If either list is empty or their
    lengths differ, nothing is written and an error message is printed instead.
    """
    if texts and vectors and len(texts) == len(vectors):
        print(f"Step 4/4: Saving raw texts and embeddings to {file_path}...")
        payload = {"texts": texts, "vectors": vectors}
        with open(file_path, 'wb') as handle:
            pickle.dump(payload, handle)
        print("Data saved successfully!")
    else:
        print("Error: Text or vector data is empty, or counts do not match. Cannot save.")


if __name__ == "__main__":
    # Pipeline: load -> clean -> embed -> save; each step prints its own progress.
    source_texts = load_texts_from_jsonl(INPUT_FILE_PATH, DATA_KEY)
    output_saved = False
    if source_texts:
        # Only the cleaned copies are embedded; the raw texts are what gets
        # saved alongside the vectors.
        cleaned_texts_for_embedding = clean_texts_for_embedding(source_texts)
        embedding_vectors = get_embeddings(cleaned_texts_for_embedding)
        if embedding_vectors:
            save_data(OUTPUT_FILE_PATH, source_texts, embedding_vectors)
            # save_data() only prints an error (it does not raise) when the
            # counts differ, so mirror its check to know whether it wrote the file.
            output_saved = len(embedding_vectors) == len(source_texts)
    if output_saved:
        print("\nEmbedding pipeline completed.")
    else:
        # Exit non-zero so a calling script does not carry on with a missing
        # or outdated output file.
        print("\nEmbedding pipeline finished without saving any output.")
        raise SystemExit(1)
